#include "opencl_source_map.hpp" 
namespace MNN { 
// OpenCL C source for the image2d-based matrix-multiply kernels, compiled at
// runtime from this string. It defines four kernels: matmul, matmul_transB,
// matmul_transA and matmul_transA_transB. When the BIAS macro is defined, each
// kernel takes an extra image `input_c` and seeds its accumulators from
// input_c at (width_blocks_idx, 0).
//
// FLOAT, FLOAT4, RI_F and WI_F are not defined in this source. They are
// presumably injected by the host as build options (e.g. half vs float);
// TODO confirm against the program-building code.
//
// The loop counter `pos` and the tail counter `remain` are `int`. They were
// previously `short`: `(pos+1)*4-channels` is computed in int, and narrowing
// it to short wraps for large `channels`, which makes the tail-masking
// comparisons (`== 1/2/3`, `>= 1/2/3`) fire on non-tail blocks and silently
// zero valid inputs.
//
// NOTE(review): this file looks generated from a .cl kernel file; the same
// change should be applied to that source so regeneration keeps it.
const char* matmul = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) ""if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { ""return; ""}\n"
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
// matmul: work-item (width_blocks_idx, height_idx) writes one FLOAT4 of the
// output row. It reads input_a at (k_block, row) and input_b rows
// 4*k_block..4*k_block+3 of column block width_blocks_idx, zeroing B rows
// past `channels`.
"__kernel void matmul(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,__private const int channels,\n"
" __private const int channel_blocks) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_idx);\n"
" FLOAT4 a;\n"
" FLOAT4 b0=0,b1=0,b2=0,b3=0;\n"
" #ifdef BIAS\n"
" FLOAT4 temp=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT result0=temp.x;\n"
" FLOAT result1=temp.y;\n"
" FLOAT result2=temp.z;\n"
" FLOAT result3=temp.w;\n"
" #else\n"
" FLOAT result0=0;\n"
" FLOAT result1=0;\n"
" FLOAT result2=0;\n"
" FLOAT result3=0;\n"
" #endif\n"
" for (int pos=0; pos<channel_blocks; pos += 1) {\n"
" a=RI_F(input_a,SAMPLER,(int2)(pos,height_idx));\n"
" int remain=(pos+1)*4-channels;\n"
" b0=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4));\n"
" b1=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4+1));\n"
" b2=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4+2));\n"
" b3=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,pos*4+3));\n"
" if (remain == 3) {\n"
" b1=0;\n"
" b2=0;\n"
" b3=0;\n"
" } else if (remain == 2) {\n"
" b2=0;\n"
" b3=0;\n"
" } else if (remain == 1) {\n"
" b3=0;\n"
" }\n"
" FLOAT4 btmp0=(FLOAT4)(b0.s0,b1.s0,b2.s0,b3.s0);\n"
" FLOAT4 btmp1=(FLOAT4)(b0.s1,b1.s1,b2.s1,b3.s1);\n"
" FLOAT4 btmp2=(FLOAT4)(b0.s2,b1.s2,b2.s2,b3.s2);\n"
" FLOAT4 btmp3=(FLOAT4)(b0.s3,b1.s3,b2.s3,b3.s3);\n"
" result0 += dot(a,btmp0);\n"
" result1 += dot(a,btmp1);\n"
" result2 += dot(a,btmp2);\n"
" result3 += dot(a,btmp3);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,height_idx),(FLOAT4)(result0,result1,result2,result3));\n"
"}\n"
// matmul_transB: same output layout as matmul. It reads input_b at
// (k_block, 4*width_blocks_idx+i) and zeroes the lanes of `a` past `channels`.
"__kernel void matmul_transB(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,__private const int channels,\n"
" __private const int channel_blocks) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_idx);\n"
" FLOAT4 a;\n"
" FLOAT4 b0=0,b1=0,b2=0,b3=0;\n"
" #ifdef BIAS\n"
" FLOAT4 temp=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT result0=temp.x;\n"
" FLOAT result1=temp.y;\n"
" FLOAT result2=temp.z;\n"
" FLOAT result3=temp.w;\n"
" #else\n"
" FLOAT result0=0;\n"
" FLOAT result1=0;\n"
" FLOAT result2=0;\n"
" FLOAT result3=0;\n"
" #endif\n"
" for (int pos=0; pos<channel_blocks; pos += 1) {\n"
" a=RI_F(input_a,SAMPLER,(int2)(pos,height_idx));\n"
" int remain=(pos+1)*4-channels;\n"
" b0=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4));\n"
" b1=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4+1));\n"
" b2=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4+2));\n"
" b3=RI_F(input_b,SAMPLER,(int2)(pos,width_blocks_idx*4+3));\n"
" if (remain == 3) {\n"
" a.y=0;\n"
" a.z=0;\n"
" a.w=0;\n"
" } else if (remain == 2) {\n"
" a.z=0;\n"
" a.w=0;\n"
" } else if (remain == 1) {\n"
" a.w=0;\n"
" }\n"
" result0 += dot(a,b0);\n"
" result1 += dot(a,b1);\n"
" result2 += dot(a,b2);\n"
" result3 += dot(a,b3);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,height_idx),(FLOAT4)(result0,result1,result2,result3));\n"
"}\n"
// matmul_transA: each work-item writes up to four output rows,
// 4*height_blocks_idx..+3, with the later rows guarded by `height`.
// input_a is read at (height_blocks_idx, k), and its rows past `channels`
// are zeroed.
" __kernel void matmul_transA(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,\n"
" __private const int channels,\n"
" __private const int channel_blocks,\n"
" __private const int height) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_blocks_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_blocks_idx);\n"
" FLOAT4 v_zero=(FLOAT4)((FLOAT)0.0);\n"
" #ifdef BIAS\n"
" FLOAT4 result0=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
" FLOAT4 result3=result0;\n"
" #else\n"
" FLOAT4 result0=0;\n"
" FLOAT4 result1=0;\n"
" FLOAT4 result2=0;\n"
" FLOAT4 result3=0;\n"
" #endif\n"
" \n"
" for (int pos=0; pos<channel_blocks; pos += 1) {\n"
" FLOAT4 a0=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos));\n"
" FLOAT4 a1=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+1));\n"
" FLOAT4 a2=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+2));\n"
" FLOAT4 a3=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+3));\n"
" FLOAT4 b0=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos));\n"
" FLOAT4 b1=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos+1));\n"
" FLOAT4 b2=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos+2));\n"
" FLOAT4 b3=RI_F(input_b,SAMPLER,(int2)(width_blocks_idx,4*pos+3));\n"
" \n"
" int remain=(pos+1)*4-channels;\n"
" a3=((remain >= 1) ? v_zero : a3);\n"
" a2=((remain >= 2) ? v_zero : a2);\n"
" a1=((remain >= 3) ? v_zero : a1);\n"
" FLOAT4 a0_trans=(FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n"
" FLOAT4 a1_trans=(FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n"
" FLOAT4 a2_trans=(FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n"
" FLOAT4 a3_trans=(FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n"
" \n"
" FLOAT4 b0_trans=(FLOAT4)(b0.x,b1.x,b2.x,b3.x);\n"
" FLOAT4 b1_trans=(FLOAT4)(b0.y,b1.y,b2.y,b3.y);\n"
" FLOAT4 b2_trans=(FLOAT4)(b0.z,b1.z,b2.z,b3.z);\n"
" FLOAT4 b3_trans=(FLOAT4)(b0.w,b1.w,b2.w,b3.w);\n"
" //matmul\n"
" result0.x += dot(a0_trans,b0_trans);\n"
" result0.y += dot(a0_trans,b1_trans);\n"
" result0.z += dot(a0_trans,b2_trans);\n"
" result0.w += dot(a0_trans,b3_trans);\n"
" \n"
" result1.x += dot(a1_trans,b0_trans);\n"
" result1.y += dot(a1_trans,b1_trans);\n"
" result1.z += dot(a1_trans,b2_trans);\n"
" result1.w += dot(a1_trans,b3_trans);\n"
" \n"
" result2.x += dot(a2_trans,b0_trans);\n"
" result2.y += dot(a2_trans,b1_trans);\n"
" result2.z += dot(a2_trans,b2_trans);\n"
" result2.w += dot(a2_trans,b3_trans);\n"
" \n"
" result3.x += dot(a3_trans,b0_trans);\n"
" result3.y += dot(a3_trans,b1_trans);\n"
" result3.z += dot(a3_trans,b2_trans);\n"
" result3.w += dot(a3_trans,b3_trans);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx),result0);\n"
" if(4*height_blocks_idx+1 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+1),result1);\n"
" if(4*height_blocks_idx+2 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+2),result2);\n"
" if(4*height_blocks_idx+3 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+3),result3);\n"
"}\n"
// matmul_transA_transB: same row-blocked output as matmul_transA, but
// input_b is read at (k_block, 4*width_blocks_idx+i) and used without a
// transpose.
"__kernel void matmul_transA_transB(GLOBAL_SIZE_2_DIMS __read_only image2d_t input_a,\n"
" __read_only image2d_t input_b,\n"
" #ifdef BIAS\n"
" __read_only image2d_t input_c,\n"
" #endif\n"
" __write_only image2d_t output_c,\n"
" __private const int channels,\n"
" __private const int channel_blocks,\n"
" __private const int height) {\n"
" const int width_blocks_idx=get_global_id(0);\n"
" const int height_blocks_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_blocks_idx);\n"
" FLOAT4 v_zero=(FLOAT4)((FLOAT)0.0);\n"
" #ifdef BIAS\n"
" FLOAT4 result0=RI_F(input_c,SAMPLER,(int2)(width_blocks_idx,0));\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
" FLOAT4 result3=result0;\n"
" #else\n"
" FLOAT4 result0=0;\n"
" FLOAT4 result1=0;\n"
" FLOAT4 result2=0;\n"
" FLOAT4 result3=0;\n"
" #endif\n"
" for (int pos=0; pos<channel_blocks; pos += 1) {\n"
" FLOAT4 a0=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos));\n"
" FLOAT4 a1=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+1));\n"
" FLOAT4 a2=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+2));\n"
" FLOAT4 a3=RI_F(input_a,SAMPLER,(int2)(height_blocks_idx,4*pos+3));\n"
" FLOAT4 b0=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx));\n"
" FLOAT4 b1=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx+1));\n"
" FLOAT4 b2=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx+2));\n"
" FLOAT4 b3=RI_F(input_b,SAMPLER,(int2)(pos,4*width_blocks_idx+3));\n"
" \n"
" int remain=(pos+1)*4-channels;\n"
" a3=((remain >= 1) ? v_zero : a3);\n"
" a2=((remain >= 2) ? v_zero : a2);\n"
" a1=((remain >= 3) ? v_zero : a1);\n"
" FLOAT4 a0_trans=(FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n"
" FLOAT4 a1_trans=(FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n"
" FLOAT4 a2_trans=(FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n"
" FLOAT4 a3_trans=(FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n"
" //matmul\n"
" result0.x += dot(a0_trans,b0);\n"
" result0.y += dot(a0_trans,b1);\n"
" result0.z += dot(a0_trans,b2);\n"
" result0.w += dot(a0_trans,b3);\n"
" \n"
" result1.x += dot(a1_trans,b0);\n"
" result1.y += dot(a1_trans,b1);\n"
" result1.z += dot(a1_trans,b2);\n"
" result1.w += dot(a1_trans,b3);\n"
" \n"
" result2.x += dot(a2_trans,b0);\n"
" result2.y += dot(a2_trans,b1);\n"
" result2.z += dot(a2_trans,b2);\n"
" result2.w += dot(a2_trans,b3);\n"
" \n"
" result3.x += dot(a3_trans,b0);\n"
" result3.y += dot(a3_trans,b1);\n"
" result3.z += dot(a3_trans,b2);\n"
" result3.w += dot(a3_trans,b3);\n"
" }\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx),result0);\n"
" if(4*height_blocks_idx+1 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+1),result1);\n"
" if(4*height_blocks_idx+2 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+2),result2);\n"
" if(4*height_blocks_idx+3 >= height) return;\n"
" WI_F(output_c,(int2)(width_blocks_idx,4*height_blocks_idx+3),result3);\n"
"}\n"
;
}
